# src/scripts/pretrain_model.py
import os
import sys
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import logging
import random
import argparse
from datetime import datetime
import numpy as np

# --- Import custom project modules (sys.path is not modified here; run as a module from the project root) ---
try:
    # This assumes the script is run as a module from the project root
    from src.mad_datasets import MADTokenDataset
    from src.models import MADModel
    from src.evaluation.visualize import plot_confusion_matrix, plot_training_history
    from src.evaluation.utils import setup_logging, save_checkpoint, save_config_used, load_checkpoint
    from src.evaluation.metrics_calculator import compute_classification_metrics
    from src.external_libs.loss.SupConLoss import SupConLoss
    from src.early_stopping import EarlyStopping
except ImportError:
    print("Error: Make sure to run this script as a module from the project root.")
    print("Example: python -m src.scripts.pretrain_model")
    sys.exit(1)


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (plus the CUDA generators
    when CUDA is available), and switches cuDNN to deterministic mode with
    benchmarking turned off.

    Args:
        seed: Integer seed handed to each generator.
    """
    # cuDNN flags are independent of the seeding calls below.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(seed)


@torch.no_grad()
def evaluate_model(model, dataloader, device, config, class_names, logger, viz_dir, epoch_tag=""):
    """Score the model's classification output on ``dataloader``.

    The model is switched to eval mode and called with ``for_supcon=False``
    so that it returns logits. Softmax probabilities, argmax predictions and
    the ground-truth labels are gathered over all batches, handed to
    ``compute_classification_metrics``, and a confusion matrix over
    ``range(config['num_classes'])`` is passed to ``plot_confusion_matrix``
    together with ``viz_dir`` and the file name ``cm_<epoch_tag>.png``.

    Gradients are disabled for the whole call via the decorator. The model
    is left in eval mode; callers must switch back to train mode themselves.

    Returns:
        Whatever ``compute_classification_metrics`` returns.
    """
    model.eval()

    predicted_classes = []
    true_classes = []
    class_probabilities = []

    batch_iter = tqdm(dataloader, desc=f"Evaluating {epoch_tag}", leave=False, ncols=100)
    for (image_batch, signal_batch), label_batch in batch_iter:
        image_batch = image_batch.to(device).float()
        signal_batch = signal_batch.to(device).float()

        class_logits = model(image_batch, signal_batch, for_supcon=False)

        class_probabilities.extend(torch.softmax(class_logits, dim=1).cpu().numpy())
        predicted_classes.extend(torch.argmax(class_logits, dim=1).cpu().numpy())
        true_classes.extend(label_batch.cpu().numpy())

    # Metrics are computed (and logged through `logger`) before plotting.
    metrics = compute_classification_metrics(
        y_true=true_classes,
        y_pred=predicted_classes,
        y_probs=class_probabilities,
        num_classes=config['num_classes'],
        class_names=class_names,
        logger_instance=logger,
    )

    label_ids = list(range(config['num_classes']))
    conf_mat = confusion_matrix(true_classes, predicted_classes, labels=label_ids)
    plot_confusion_matrix(conf_mat, class_names, viz_dir, f"cm_{epoch_tag}.png")

    return metrics


def _resolve_monitor_metric(eval_metrics, monitor_name, aliases):
    """Look up the monitored validation value by name.

    The raw evaluation metrics are searched first, then ``aliases``.
    Returns ``None`` when the name is unknown to both.
    """
    if monitor_name in eval_metrics:
        return eval_metrics[monitor_name]
    return aliases.get(monitor_name)


def run_pretraining(config):
    """Run SupCon pre-training of the MAD model as described by ``config``.

    Sets the random seeds, builds the train/val datasets and loaders, the
    model, the SupCon criterion, an AdamW optimizer and an early stopper,
    optionally resumes from a checkpoint, then trains epoch by epoch. On
    evaluation epochs the model is evaluated, a checkpoint is saved (flagged
    as best when the monitored metric improved) and early stopping is
    consulted. Finally the training history is plotted.

    Args:
        config: Dict of settings.
            Required keys: ``output_dir``, ``dataset_root_dir``,
            ``num_classes``, ``batch_size``, ``supcon_temperature``,
            ``learning_rate``, ``weight_decay``, ``early_stopping_patience``,
            ``early_stopping_monitor_metric``, ``num_epochs``, ``do_eval``;
            ``eval_every_n_epochs`` when ``do_eval`` is truthy;
            ``gradient_clip_val`` when ``use_gradient_clipping`` is truthy.
            Optional keys (default): ``seed`` (42), ``num_workers`` (0),
            ``early_stopping_mode`` ('max'), ``resume_path`` (unset),
            ``use_gradient_clipping`` (off), ``do_early_stopping`` (True),
            ``min_epochs_for_early_stopping`` (0).
            Keys listed by ``MADModel.get_param_keys()`` are forwarded to the
            model constructor.

    Raises:
        ValueError: If the training loader yields no batches (fewer training
            samples than ``batch_size`` while ``drop_last=True``).
    """
    set_seed(config.get('seed', 42))

    output_dir = config["output_dir"]
    os.makedirs(output_dir, exist_ok=True)

    # Logger returned by the project's setup_logging helper for this log file.
    logger = setup_logging(os.path.join(output_dir, "pretrain_log.txt"))

    logger.info("--- MAD Model SupCon Pre-training Started ---")
    logger.info(f"Output directory: {output_dir}")
    save_config_used(config, os.path.join(output_dir, "config_pretrain_used.yaml"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    logger.info("Loading datasets...")
    train_dataset = MADTokenDataset(root_dir=config['dataset_root_dir'], usage="train", num_classes=config['num_classes'])
    eval_dataset = MADTokenDataset(root_dir=config['dataset_root_dir'], usage="val", num_classes=config['num_classes'])
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, drop_last=True, num_workers=config.get('num_workers', 0))
    eval_loader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config.get('num_workers', 0))
    class_names = train_dataset.get_class_names()
    logger.info(f"Training samples: {len(train_dataset)}, Validation samples: {len(eval_dataset)}")

    # With drop_last=True a dataset smaller than one batch produces zero
    # batches; fail early with a clear message instead of dividing by zero
    # when averaging the epoch loss.
    num_train_batches = len(train_loader)
    if num_train_batches == 0:
        raise ValueError(
            f"Training loader is empty: {len(train_dataset)} samples with "
            f"batch_size={config['batch_size']} and drop_last=True yield no batches."
        )

    logger.info("Initializing MAD model for pre-training...")
    model_params = {k: v for k, v in config.items() if k in MADModel.get_param_keys()}
    model = MADModel(**model_params).to(device)

    criterion = SupConLoss(temperature=config['supcon_temperature']).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

    early_stopper = EarlyStopping(
        patience=config['early_stopping_patience'],
        monitor_metric_name=config['early_stopping_monitor_metric'],
        mode=config.get('early_stopping_mode', 'max'),
        trace_func=logger.info
    )
    # The stopper is always built because its mode/metric name also drive
    # best-checkpoint selection; this flag only gates the actual stopping.
    use_early_stopping = config.get('do_early_stopping', True)
    monitor_name = early_stopper.monitor_metric_name
    warned_unknown_metric = False

    start_epoch = 0
    best_metric = -float('inf') if early_stopper.mode == 'max' else float('inf')
    history = {"train_loss": [], "val_accuracy": [], "val_f1": []}

    if config.get('resume_path'):
        logger.info(f"Resuming pre-training from: {config['resume_path']}")
        ckpt = load_checkpoint(config['resume_path'], model, optimizer, device, logger)
        if ckpt:
            start_epoch = ckpt.get('epoch', 0)
            best_metric = ckpt.get('best_metric', best_metric)
            history = ckpt.get('history', history)

    # Tracked separately from the loop variable so the final plot still works
    # when the epoch range is empty (e.g. resuming a finished run).
    epochs_completed = start_epoch

    logger.info(f"Starting pre-training from epoch {start_epoch + 1} for {config['num_epochs']} epochs.")
    for epoch in range(start_epoch, config['num_epochs']):
        model.train()
        epoch_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Pre-training]", leave=True, ncols=100)

        for (x_img, x_sig), y_true in progress_bar:
            x_img, x_sig, y_true = x_img.to(device).float(), x_sig.to(device).float(), y_true.to(device)
            optimizer.zero_grad()
            features = model(x_img, x_sig, for_supcon=True)
            # A singleton dimension is inserted at position 1 before the
            # features are handed to the SupCon criterion.
            loss = criterion(features.unsqueeze(1), labels=y_true)
            loss.backward()
            if config.get('use_gradient_clipping'):
                nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip_val'])
            optimizer.step()
            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{optimizer.param_groups[0]['lr']:.1e}")

        avg_train_loss = epoch_loss / num_train_batches
        history["train_loss"].append(avg_train_loss)
        logger.info(f"Epoch {epoch+1} - Train Loss (SupCon): {avg_train_loss:.4f}")
        epochs_completed = epoch + 1

        # NOTE: checkpoints are only written on evaluation epochs.
        if config['do_eval'] and (epoch + 1) % config['eval_every_n_epochs'] == 0:
            eval_metrics = evaluate_model(model, eval_loader, device, config, class_names, logger, viz_dir=output_dir, epoch_tag=f"epoch_{epoch+1}")
            val_f1_key = f"f1_score_{'macro' if config['num_classes'] > 2 else 'binary'}"
            current_f1 = eval_metrics.get(val_f1_key, 0.0)
            current_accuracy = eval_metrics.get("accuracy", 0.0)
            history["val_f1"].append(current_f1)
            history["val_accuracy"].append(current_accuracy)

            # The monitor name may be a raw evaluation key or one of this
            # script's history names ("val_f1" / "val_accuracy"). Without the
            # alias lookup a name such as "val_f1" would silently resolve to
            # 0.0 on every epoch.
            metric_to_monitor = _resolve_monitor_metric(
                eval_metrics,
                monitor_name,
                aliases={"val_f1": current_f1, "val_accuracy": current_accuracy},
            )
            if metric_to_monitor is None:
                if not warned_unknown_metric:
                    logger.info(
                        f"WARNING: monitored metric '{monitor_name}' was not found in the "
                        f"evaluation results; falling back to 0.0."
                    )
                    warned_unknown_metric = True
                metric_to_monitor = 0.0

            if early_stopper.mode == 'max':
                is_best = metric_to_monitor > best_metric
            elif early_stopper.mode == 'min':
                is_best = metric_to_monitor < best_metric
            else:
                is_best = False
            if is_best:
                best_metric = metric_to_monitor
                logger.info(f"⭐ New best metric ({monitor_name}): {best_metric:.4f} at epoch {epoch+1}")

            save_checkpoint({
                'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(),
                'best_metric': best_metric, 'history': history
            }, is_best, checkpoint_dir=output_dir, best_filename="model_best_pretrained.pth.tar")

            if use_early_stopping and (epoch + 1) >= config.get("min_epochs_for_early_stopping", 0):
                early_stopper(metric_to_monitor)
                if early_stopper.early_stop:
                    logger.info("Early stopping triggered.")
                    break

    logger.info("--- Pre-training session finished. ---")
    plot_training_history(history, epochs_completed, output_dir=output_dir, file_name="pretraining_history.png")


if __name__ == "__main__":
    # Command-line options: every flag is a string; (flag, default, help).
    cli_options = (
        ('--data_dir', "data/processed", "Path to the processed data directory."),
        ('--output_dir', "pretrained_weights", "Directory to save weights and config file."),
        ('--resume_path', None, "Path to a checkpoint to resume pre-training from."),
    )
    parser = argparse.ArgumentParser(description="Pre-train the MAD model using SupCon loss.")
    for flag, default_value, help_text in cli_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    args = parser.parse_args()

    # Default pre-training configuration; paths come from the CLI.
    pretrain_config = {
        # Run identity and I/O
        "seed": 42,
        "experiment_tag": "MAD_Pretrain_Default",
        "dataset_root_dir": args.data_dir,
        "output_dir": args.output_dir,
        "resume_path": args.resume_path,
        # Data / model dimensions
        "num_classes": 5,
        "num_img_patches": 256,
        "img_patch_flat_dim": 768,
        "num_sig_patches": 2560,
        "sig_patch_dim": 60,
        "embed_dim": 128,
        "limoe_depth": 1,
        "limoe_heads": 3,
        "limoe_dim_head": 42,  # Placeholder; overwritten below from embed_dim // limoe_heads
        "limoe_num_experts": 1,
        "limoe_ff_mult": 2,
        "limoe_top_k": 1,
        "limoe_dropout": 0.1,
        "realnvp_img_layers": 3,
        "realnvp_sig_layers": 5,
        "proj_head_out_dim": 128,
        # Optimisation
        "batch_size": 8,
        "num_epochs": 200,
        "num_workers": 0,
        "learning_rate": 1e-5,
        "weight_decay": 0.01,
        "supcon_temperature": 0.5,
        "use_gradient_clipping": True,
        "gradient_clip_val": 0.5,
        # Evaluation and early stopping
        "do_eval": True,
        "eval_every_n_epochs": 1,
        "do_early_stopping": True,
        "early_stopping_patience": 15,
        "early_stopping_monitor_metric": "val_f1",
        "early_stopping_mode": "max",
        "min_epochs_for_early_stopping": 30,
    }

    # Derive the per-head dimension by floor division, warning when the
    # embedding width does not split evenly across the heads.
    embed_dim = pretrain_config['embed_dim']
    num_heads = pretrain_config['limoe_heads']
    per_head_dim, remainder = divmod(embed_dim, num_heads)
    if remainder != 0:
        print(f"Warning: embed_dim ({embed_dim}) is not divisible by limoe_heads ({num_heads}).")
    pretrain_config['limoe_dim_head'] = per_head_dim

    run_pretraining(config=pretrain_config)